Pytorch 显存节约

关于Pytorch 训练时显存ok,但是load checkpoints时显存out of memory问题

这个问题主要是由于一下几点:

  • 在load时先将checkpoints load到了gpu上,再load到model的地址,这样中间就多了一次存储。

  • 模型在load 之前就使用cuda()放在了gpu上,这样也会造成空间使冗余的情况

解决方式:

使用先load 在cpu上,然后load到model地址,最后push到gpu上的操作。

1
2
3
4
cpts = torch.load(os.path.join(checkponits_dir, "model.pth"), map_location='cpu')

model.load_state_dict(cpts)
model = model.cuda()